
import os

import time

import numpy as np

import torchvision.utils as vutils


class Visualizer():
    """Plot training curves and images to visdom, and save images and loss logs to disk.

    Output layout (all under ``opt.outf/opt.name``):
        train/images/   -- images written by ``save_current_images``
        test/images/    -- directory reserved for test-time images
        loss_log.txt    -- appended text log of losses / performance
    """

    def __init__(self, opt):
        """Create output directories, open the log, and optionally connect to visdom.

        Args:
            opt: Options object. Reads ``display_id``, ``name``, ``display``,
                ``outf`` and, when ``display`` is truthy, ``display_server`` and
                ``display_port``. ``niter`` is read later by ``print_current_errors``.
        """
        self.display_id = opt.display_id
        self.win_size = 256
        self.name = opt.name
        self.opt = opt

        if self.opt.display:
            # Imported lazily so visdom is only required when display is enabled.
            import visdom
            self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port)

        # --
        # Accumulated data for the loss and performance line plots.
        # Both are created lazily on the first plot call.
        self.plot_data = None
        self.plot_res = None

        # --
        # Train/test image directories, both relative to the run directory.
        # Previously img_dir was a hard-coded, user-specific absolute path, which
        # broke construction on any other machine.
        run_dir = os.path.join(opt.outf, opt.name)
        self.img_dir = os.path.join(run_dir, 'train', 'images')
        self.tst_img_dir = os.path.join(run_dir, 'test', 'images')
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.img_dir, exist_ok=True)
        os.makedirs(self.tst_img_dir, exist_ok=True)

        # --
        # Log file: append a timestamped header for this run.
        self.log_name = os.path.join(run_dir, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    @staticmethod
    def normalize(inp):
        """Min-max scale the input to approximately [0, 1].

        Args:
            inp: Array or tensor supporting ``min()``/``max()`` and arithmetic.

        Returns:
            The scaled input, same type as ``inp``. The ``1e-5`` term avoids
            division by zero when the input is constant.
        """
        return (inp - inp.min()) / (inp.max() - inp.min() + 1e-5)

    def _update_line_plot(self, store, epoch, counter_ratio, values, title, ylabel, win):
        """Append one point to ``store`` and redraw it in visdom window ``win``.

        ``store`` is a dict with keys ``'X'`` (x positions), ``'Y'`` (one row of
        values per point) and ``'legend'`` (the series keys, fixed on first use).
        """
        store['X'].append(epoch + counter_ratio)
        store['Y'].append([values[k] for k in store['legend']])
        n_series = len(store['legend'])
        self.vis.line(
            # One identical X column per series so X and Y have the same shape.
            X=np.stack([np.array(store['X'])] * n_series, 1),
            Y=np.array(store['Y']),
            opts={
                'title': title,
                'legend': store['legend'],
                'xlabel': 'Epoch',
                'ylabel': ylabel,
            },
            win=win,
        )

    def plot_current_errors(self, epoch, counter_ratio, errors):
        """Plot current errors in visdom window 4.

        Args:
            epoch (int): Current epoch.
            counter_ratio (float): Fractional progress within the epoch, added to ``epoch``
                to place the point on the x-axis.
            errors (OrderedDict): Error values for the current step. The keys seen on
                the first call fix the legend; later calls must contain those keys.
        """
        if getattr(self, 'plot_data', None) is None:
            self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
        self._update_line_plot(self.plot_data, epoch, counter_ratio, errors,
                               self.name + ' loss over time', 'Loss', 4)

    def plot_performance(self, epoch, counter_ratio, performance):
        """Plot performance metrics in visdom window 5.

        Args:
            epoch (int): Current epoch.
            counter_ratio (float): Fractional progress within the epoch, added to ``epoch``
                to place the point on the x-axis.
            performance (OrderedDict): Metric values for the current step. The keys seen
                on the first call fix the legend; later calls must contain those keys.
        """
        if getattr(self, 'plot_res', None) is None:
            self.plot_res = {'X': [], 'Y': [], 'legend': list(performance.keys())}
        # Leading space added so the title reads "<name> Performance Metrics".
        self._update_line_plot(self.plot_res, epoch, counter_ratio, performance,
                               self.name + ' Performance Metrics', 'Stats', 5)

    def _log(self, message):
        """Print ``message`` and append it as a line to the loss log file."""
        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)

    def print_current_errors(self, epoch, errors):
        """Print current errors and append them to the log.

        Args:
            epoch (int): Current epoch (shown against ``opt.niter``).
            errors (OrderedDict): Error values for the current epoch.
        """
        message = '   Loss: [%d/%d] ' % (epoch, self.opt.niter)
        message += ''.join('%s: %.3f ' % (key, val) for key, val in errors.items())
        self._log(message)

    def print_current_performance(self, performance, best):
        """Print current performance results and append them to the log.

        Args:
            performance (OrderedDict): Performance metrics of the model.
            best (float): Best AUC so far (formatted with three decimals).
        """
        message = '   '
        message += ''.join('%s: %.3f ' % (key, val) for key, val in performance.items())
        message += 'max AUC: %.3f' % best
        self._log(message)

    def display_current_images(self, reals, fakes, fixed):
        """Display current images in visdom windows 1-3.

        Args:
            reals: Real images (tensor-like; ``.cpu().numpy()`` is called on it).
            fakes: Generated images (same requirements).
            fixed: Images generated from fixed inputs (same requirements).
        """
        reals = self.normalize(reals.cpu().numpy())
        fakes = self.normalize(fakes.cpu().numpy())
        fixed = self.normalize(fixed.cpu().numpy())

        self.vis.images(reals, win=1, opts={'title': 'Reals'})
        self.vis.images(fakes, win=2, opts={'title': 'Fakes'})
        self.vis.images(fixed, win=3, opts={'title': 'Fixed'})

    def save_current_images(self, epoch, reals, fakes, fixed):
        """Save image grids for the given epoch into ``self.img_dir``.

        ``reals.png`` and ``fakes.png`` are overwritten each call. The fixed
        images are kept per epoch as ``fixed_fakes_<epoch+1>.png``.

        Args:
            epoch (int): Current epoch (0-based; the saved file number is ``epoch + 1``).
            reals: Real images.
            fakes: Generated images.
            fixed: Images generated from fixed inputs.
        """
        vutils.save_image(reals, os.path.join(self.img_dir, 'reals.png'), normalize=True)
        vutils.save_image(fakes, os.path.join(self.img_dir, 'fakes.png'), normalize=True)
        vutils.save_image(fixed, os.path.join(self.img_dir, 'fixed_fakes_%03d.png' % (epoch + 1)),
                          normalize=True)